-
Notifications
You must be signed in to change notification settings - Fork 139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Keras APIs and fix save-restore issue in TF2 #119
Support Keras APIs and fix save-restore issue in TF2 #119
Conversation
tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py
Show resolved
Hide resolved
@@ -771,7 +816,8 @@ def safe_embedding_lookup_sparse( | |||
entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the | |||
default. | |||
default_id: The id to use for an entry with no features. | |||
name: A name for this operation (optional). | |||
name: A name for this operation. Name is optional in graph mode and required |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How to confirm the name
is set in eager mode?
|
||
# TODO(Lifann) Use more resonable name. | ||
def _create_trainable(trainable_name): | ||
return de.TrainableWrapper(params, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Plz check here , there is assume that 'TrainableWrapper' is always appear in the TW name, and the placer won't place TW onto PSs for multi-thread safe . device placement
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accept. Use ops.colocate_with
to force the TW to reside with ids.
@@ -81,10 +82,11 @@ def __init__(self, params, ids, max_norm, *args, **kwargs): | |||
self.prefetch_values_op = None | |||
self.model_mode = kwargs.get("model_mode") | |||
kwargs.pop("model_mode") | |||
self._slots_track = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_slots_tracked or _tracked_slots maybe better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree
@@ -443,7 +464,8 @@ def embedding_lookup( | |||
params: A dynamic_embedding.Variable instance. | |||
ids: A tensor with any shape as same dtype of params.key_dtype. | |||
partition_strategy: No used, for API compatiblity with `nn.emedding_lookup`. | |||
name: A name for the operation (optional). | |||
name: A name for the operation. Name is optional in graph mode and required | |||
in eager mode. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a check is needed to make sure the name
is set in eager mode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's checked when creating TW. I moved it to the front of func entry now.
Description
This PR makes Keras APIs become compatible to TFRA, and also fixes the save/restore issue in TensorFlow 2.x version.
Thanks to @ccsquare 's suggestions on save/restore.
Main Features
1. Support object-based programming
2. Add a demo on amazon digital video games in object-based APIs.
3. Provide fine-grained API to access non-explicitly defined
dynamic_embedding.Variable
and deprecatedynamic_embedding.GraphKeys
.4. Fix relative bugs: